{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cfbd3efb",
   "metadata": {},
   "source": [
    "## basics"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6da35c4f",
   "metadata": {},
   "source": [
    "- Spectral Normalization for Generative Adversarial Networks\n",
    "    - https://arxiv.org/abs/1802.05957\n",
    "- pytorch 的两个接口\n",
    "    - old：`torch.nn.utils.spectral_norm`\n",
    "    - new：`torch.nn.utils.parametrizations.spectral_norm`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "565b0f77",
   "metadata": {},
   "source": [
    "## 矩阵的谱范数"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6279cdf3",
   "metadata": {},
   "source": [
    "$$\n",
    "\\|A\\|_2 = \\max_{\\|x\\|\\neq 0}\\frac{\\|Ax\\|_2}{\\|x\\|_2}=\\sqrt{\\lambda_\\max(A^TA)}=\\sigma_\\max(A)\n",
    "$$\n",
    "\n",
    "- A的谱范数 = A的最大奇异值 = A^T·A的最大特征值的平方根\n",
    "- The spectral norm (also know as Induced 2-norm) is the maximum singular value of a matrix. Intuitively, you can think of it as the maximum 'scale', by which the matrix can 'stretch' a vector.\n",
    "\n",
    "- The maximum singular value is the square root of the maximum eigenvalue or the maximum eigenvalue if the matrix is symmetric/hermitian"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2eb8e42",
   "metadata": {},
   "source": [
    "- 两种计算方法\n",
    "    - svd 分解\n",
    "    - 幂迭代法 (Power Iteration Method)"
   ]
  },
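  {
   "cell_type": "markdown",
   "id": "b3f1a7c2",
   "metadata": {},
   "source": [
    "The power iteration method listed above can be sketched as follows. This is a minimal illustration, not the implementation PyTorch uses: the helper name `power_iteration_spectral_norm`, the iteration count, and the fixed seed are assumptions made for this example, and `A_demo` is copied from the printed output of `A` further down. The loop alternates $u \\leftarrow Av/\\|Av\\|_2$ and $v \\leftarrow A^Tu/\\|A^Tu\\|_2$, then estimates $\\sigma_{\\max}(A)$ as $u^TAv$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d2e8f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def power_iteration_spectral_norm(A, n_iters=100, seed=0):\n",
    "    # hypothetical helper: estimate the largest singular value of A\n",
    "    rng = np.random.default_rng(seed)\n",
    "    A = np.asarray(A, dtype=float)\n",
    "    # random unit start vector in the input space of A\n",
    "    v = rng.standard_normal(A.shape[1])\n",
    "    v /= np.linalg.norm(v)\n",
    "    for _ in range(n_iters):\n",
    "        # one step towards the top left-singular vector\n",
    "        u = A @ v\n",
    "        u /= np.linalg.norm(u)\n",
    "        # one step towards the top right-singular vector\n",
    "        v = A.T @ u\n",
    "        v /= np.linalg.norm(v)\n",
    "    # for unit u and v, u^T A v approximates sigma_max(A)\n",
    "    return float(u @ A @ v)\n",
    "\n",
    "A_demo = np.array([[1, 1, 1, 2],\n",
    "                   [1, 0, 0, 1],\n",
    "                   [4, 4, 3, 0],\n",
    "                   [2, 1, 4, 4],\n",
    "                   [1, 2, 1, 4]])\n",
    "sigma_pi = power_iteration_spectral_norm(A_demo)\n",
    "# reference value from a full SVD\n",
    "sigma_svd = np.linalg.svd(A_demo, compute_uv=False)[0]\n",
    "print(sigma_pi, sigma_svd)"
   ]
  },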
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "44dc6151",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T12:49:12.302552Z",
     "start_time": "2023-08-02T12:49:12.290800Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1, 1, 1, 2],\n",
       "       [1, 0, 0, 1],\n",
       "       [4, 4, 3, 0],\n",
       "       [2, 1, 4, 4],\n",
       "       [1, 2, 1, 4]])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "A = np.random.randint(0, 5, (5, 4))\n",
    "A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "b17d7e1c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:13:18.176717Z",
     "start_time": "2023-08-02T13:13:18.166111Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.3655181 ],\n",
       "       [1.86415228],\n",
       "       [0.98671532],\n",
       "       [0.14176211]])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = np.random.randn(4, 1)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "7df08223",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:13:19.214432Z",
     "start_time": "2023-08-02T13:13:19.194105Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.14531369176711"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linalg.norm(x, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "40d36479",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:13:20.270393Z",
     "start_time": "2023-08-02T13:13:20.260167Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15.363845704020635"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linalg.norm(A.dot(x), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "9a011186",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:13:21.927990Z",
     "start_time": "2023-08-02T13:13:21.916578Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[-0.27555358,  0.15562145,  0.16779797, -0.10907471, -0.92725333],\n",
       "        [-0.11022474,  0.08816295,  0.11996284, -0.96553443,  0.18283868],\n",
       "        [-0.57781611, -0.80835168,  0.09880305,  0.01432065,  0.05223962],\n",
       "        [-0.61963188,  0.36789398, -0.68201428,  0.04185011,  0.11753915],\n",
       "        [-0.44057418,  0.42335663,  0.6946562 ,  0.23214104,  0.30037784]]),\n",
       " array([9.17512311, 4.43478733, 2.08144212, 0.90408843]),\n",
       " array([[-0.47703782, -0.44550882, -0.5371158 , -0.53428778],\n",
       "        [-0.41275435, -0.42012793, -0.08444623,  0.80373827],\n",
       "        [ 0.00653425,  0.610302  , -0.75389739,  0.24316146],\n",
       "        [-0.77590339,  0.50253943,  0.368801  , -0.09702511]]))"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.linalg.svd(A)"
   ]
  },
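  {
   "cell_type": "markdown",
   "id": "d5e3f9a4",
   "metadata": {},
   "source": [
    "The cells above compute $\\|x\\|_2$, $\\|Ax\\|_2$ and the SVD separately. As a small check of the definition (a sketch only; the random matrix, the seed, and the `_chk` variable names are assumptions for illustration), the ratio $\\|Ax\\|_2/\\|x\\|_2$ should stay at or below the largest singular value for random $x$, reach it when $x$ is the first right-singular vector, and agree with $\\sqrt{\\lambda_{\\max}(A^TA)}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f4a0b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "A_chk = rng.integers(0, 5, size=(5, 4)).astype(float)\n",
    "U_chk, s_chk, Vh_chk = np.linalg.svd(A_chk)\n",
    "sigma_max = s_chk[0]\n",
    "\n",
    "# stretch ratios ||Ax|| / ||x|| for many random directions (one per column)\n",
    "xs = rng.standard_normal((4, 1000))\n",
    "ratios = np.linalg.norm(A_chk @ xs, axis=0) / np.linalg.norm(xs, axis=0)\n",
    "\n",
    "# the first right-singular vector attains the maximum stretch\n",
    "v1 = Vh_chk[0]\n",
    "ratio_v1 = np.linalg.norm(A_chk @ v1) / np.linalg.norm(v1)\n",
    "\n",
    "# square root of the largest eigenvalue of A^T A\n",
    "sigma_from_eig = np.sqrt(np.linalg.eigvalsh(A_chk.T @ A_chk).max())\n",
    "\n",
    "# for a 2-D array, ord=2 gives the spectral norm directly\n",
    "sigma_from_norm = np.linalg.norm(A_chk, 2)\n",
    "print(ratios.max(), ratio_v1, sigma_max, sigma_from_eig, sigma_from_norm)"
   ]
  },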
  {
   "cell_type": "markdown",
   "id": "d8c09b63",
   "metadata": {},
   "source": [
    "## pytorch api"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "89ce2ecf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:00:00.569859Z",
     "start_time": "2023-08-02T13:00:00.563337Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch \n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "0bdb1931",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:00:01.827143Z",
     "start_time": "2023-08-02T13:00:01.820223Z"
    }
   },
   "outputs": [],
   "source": [
    "m = nn.Linear(5, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "6dd5ec7a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:00:07.952560Z",
     "start_time": "2023-08-02T13:00:07.941066Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2564,  0.0280, -0.3555, -0.1643, -0.3643],\n",
       "        [-0.0410, -0.4005,  0.2989, -0.2165,  0.2943],\n",
       "        [ 0.2819,  0.1352,  0.4254,  0.1120,  0.1435],\n",
       "        [-0.2132, -0.3698, -0.2455, -0.0125, -0.3623]],\n",
       "       grad_fn=<CloneBackward0>)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W = m.weight.clone()\n",
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "315575b5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:00:16.693750Z",
     "start_time": "2023-08-02T13:00:16.676071Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2870,  0.0313, -0.3978, -0.1838, -0.4076],\n",
       "        [-0.0459, -0.4482,  0.3345, -0.2423,  0.3293],\n",
       "        [ 0.3154,  0.1513,  0.4760,  0.1253,  0.1606],\n",
       "        [-0.2386, -0.4139, -0.2747, -0.0140, -0.4054]], grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m_sn = nn.utils.parametrizations.spectral_norm(m)\n",
    "m_sn.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "f9a04e60",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:00:32.214899Z",
     "start_time": "2023-08-02T13:00:32.202596Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.89362526, 0.6472171 , 0.37808868, 0.25057557], dtype=float32)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "U, s, V = np.linalg.svd(W.detach().numpy())\n",
    "s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "06ef3058",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-08-02T13:00:41.337483Z",
     "start_time": "2023-08-02T13:00:41.324800Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2870,  0.0313, -0.3978, -0.1838, -0.4076],\n",
       "        [-0.0459, -0.4482,  0.3345, -0.2423,  0.3293],\n",
       "        [ 0.3154,  0.1513,  0.4760,  0.1253,  0.1606],\n",
       "        [-0.2386, -0.4139, -0.2747, -0.0140, -0.4054]], grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W/s[0]"
   ]
  },
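  {
   "cell_type": "markdown",
   "id": "f7a5b1c6",
   "metadata": {},
   "source": [
    "As a further check (a sketch, assuming a recent PyTorch in which `torch.linalg.matrix_norm` and `torch.nn.utils.parametrizations.spectral_norm` are both available; the seed and the names `layer`, `W0`, `W_sn` are illustrative), the normalized weight should have a spectral norm close to 1. The match is approximate rather than exact, because the parametrization estimates the largest singular value by power iteration instead of a full SVD."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b6c2d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "layer = nn.Linear(5, 4)\n",
    "# keep the un-normalized weight for comparison\n",
    "W0 = layer.weight.detach().clone()\n",
    "\n",
    "layer_sn = nn.utils.parametrizations.spectral_norm(layer)\n",
    "W_sn = layer_sn.weight.detach().clone()\n",
    "\n",
    "# ord=2 is the spectral norm (largest singular value)\n",
    "sigma0 = torch.linalg.matrix_norm(W0, ord=2)\n",
    "sigma_sn = torch.linalg.matrix_norm(W_sn, ord=2)\n",
    "print(sigma0.item(), sigma_sn.item())"
   ]
  },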
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5c4fc6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
